# --------------------------------------------#
#   This script is used to inspect the network structure
# --------------------------------------------#
from nets.ssd import SSD300

# With  1 class:   Total params: 6,136,170, Trainable params: 6,110,826
# With 21 classes: Total params: 8,855,570, Trainable params: 8,830,226
if __name__ == "__main__":
    # Build an SSD300 network and print its summary so the architecture
    # (layers and parameter counts) can be inspected by eye.

    # Input shape handed to SSD300; presumably (height, width, channels) —
    # confirm against nets.ssd.SSD300.
    shape = [300, 300, 3]

    # Number of classes handed to SSD300. Change this (e.g. to 21) to see
    # how the parameter count varies with the class count.
    n_classes = 1

    net = SSD300(shape, n_classes)
    net.summary()

    # To map layer indices to layer names, iterate over net.layers with
    # enumerate() and print each index together with layer.name.
